package com.garbagecode.resultbounds;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.ResultMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.session.Configuration;

/**
 * Creates MyBatis objects: {@link MappedStatement}, {@link SqlSource} and
 * {@link ResultMap}.
 *
 * <p>{@link MappedStatement} copies are built reflectively: a fresh instance is
 * obtained from the declared no-argument constructor and the instance fields
 * of the source statement are copied onto it one by one, with the
 * {@code sqlSource} and {@code resultMaps} fields optionally replaced.
 */
public class MyBatisClassCreator {
  /** Name of the {@link MappedStatement} field holding its SQL source. */
  private static final String SQL_SOURCE_FIELD_NAME = "sqlSource";

  /** Name of the {@link MappedStatement} field holding its result maps. */
  private static final String RESULT_MAPS_FIELD_NAME = "resultMaps";

  /** Declared no-argument constructor of {@link MappedStatement}, forced accessible. */
  private static final Constructor<MappedStatement> MAPPED_STATEMENT_CONSTRUCTOR;

  /** The {@code sqlSource} field of {@link MappedStatement}, forced accessible. */
  private static final Field SQL_SOURCE_FIELD;

  /** The {@code resultMaps} field of {@link MappedStatement}, forced accessible. */
  private static final Field RESULT_MAPS_FIELD;

  /**
   * Every instance field declared by {@link MappedStatement} other than
   * {@code sqlSource} and {@code resultMaps}, forced accessible.
   */
  private static final Field[] OTHER_FIELDS;

  //
  // Resolve the reflective handles above once, at class-load time.
  //
  static {
    Class<MappedStatement> clazz = MappedStatement.class;

    try {
      Constructor<MappedStatement> constructor = clazz.getDeclaredConstructor();
      constructor.setAccessible(true);
      MAPPED_STATEMENT_CONSTRUCTOR = constructor;

      Field sqlSourceField = clazz.getDeclaredField(SQL_SOURCE_FIELD_NAME);
      sqlSourceField.setAccessible(true);
      SQL_SOURCE_FIELD = sqlSourceField;

      Field resultMapsField = clazz.getDeclaredField(RESULT_MAPS_FIELD_NAME);
      resultMapsField.setAccessible(true);
      RESULT_MAPS_FIELD = resultMapsField;

      List<Field> copyableFields = new ArrayList<>();

      for (Field field : clazz.getDeclaredFields()) {
        // Static fields are shared by all instances, so there is nothing to
        // copy; writing a static final field through reflection would also
        // fail with IllegalAccessException.
        if (Modifier.isStatic(field.getModifiers())) {
          continue;
        }

        String fieldName = field.getName();
        if (SQL_SOURCE_FIELD_NAME.equals(fieldName) || RESULT_MAPS_FIELD_NAME.equals(fieldName)) {
          continue;
        }

        field.setAccessible(true);
        copyableFields.add(field);
      }

      OTHER_FIELDS = copyableFields.toArray(new Field[0]);
    } catch (SecurityException | NoSuchMethodException | NoSuchFieldException e) {
      throw new RuntimeException("Cannot prepare reflective access to " + clazz.getName(), e);
    }
  }


  /**
   * Creates a new {@link MappedStatement} whose fields are copied from
   * {@code mappedStatement}, optionally replacing its SQL source and result maps.
   *
   * <p>The copy is shallow: field values are shared with the source statement.
   * All fields other than {@code sqlSource} and {@code resultMaps} are copied
   * verbatim, even when replacements are supplied.
   * NOTE(review): any other field whose value was derived from the original
   * result maps is therefore not recomputed for the replacement list - confirm
   * this is acceptable for the callers.
   *
   * @param mappedStatement the statement to copy; must not be {@code null}
   * @param sqlSource the SQL source for the copy, or {@code null} to keep the one
   *        held by {@code mappedStatement}
   * @param resultMaps the result maps for the copy, or {@code null} to keep the
   *        ones held by {@code mappedStatement}; the list is stored as given
   * @return the newly created statement
   * @throws IllegalArgumentException if {@code mappedStatement} is {@code null}
   * @throws RuntimeException if the reflective instantiation or field copy fails
   */
  public MappedStatement createMappedStatement(MappedStatement mappedStatement, 
                                               SqlSource sqlSource, 
                                               List<ResultMap> resultMaps) {
    if (mappedStatement == null) {
      throw new IllegalArgumentException("mappedStatement null");
    }

    try {
      MappedStatement copy = MAPPED_STATEMENT_CONSTRUCTOR.newInstance();

      // Initialise the copy from the source statement
      // (everything except sqlSource and resultMaps).
      for (Field field : OTHER_FIELDS) {
        field.set(copy, field.get(mappedStatement));
      }

      // An explicit argument wins; otherwise fall back to the source statement's value.
      Object effectiveSqlSource = sqlSource != null
          ? sqlSource
          : SQL_SOURCE_FIELD.get(mappedStatement);
      Object effectiveResultMaps = resultMaps != null
          ? resultMaps
          : RESULT_MAPS_FIELD.get(mappedStatement);

      SQL_SOURCE_FIELD.set(copy, effectiveSqlSource);
      RESULT_MAPS_FIELD.set(copy, effectiveResultMaps);

      return copy;
    } catch (InstantiationException | IllegalAccessException | IllegalArgumentException 
        | InvocationTargetException | SecurityException e) {
      throw new RuntimeException("Failed to copy MappedStatement " + mappedStatement.getId(), e);
    }
  }
  
  /**
   * Creates an immutable single-element list holding a {@link ResultMap} for
   * the given type.
   *
   * <p>The result map is built with a {@code null} configuration, no explicit
   * result mappings, and an id of the form
   * {@code <MyBatisClassCreator class name>.<clazz name>}. The original author
   * left no rationale for this exact construction.
   * NOTE(review): presumably MyBatis falls back to its default mapping for
   * {@code clazz} when no mappings are given - confirm against the MyBatis
   * version in use.
   *
   * @param clazz the result type; must not be {@code null}
   * @return an unmodifiable list containing exactly one result map
   * @throws IllegalArgumentException if {@code clazz} is {@code null}
   */
  public List<ResultMap> createResultMaps(Class<?> clazz) {
    if (clazz == null) {
      throw new IllegalArgumentException("clazz null");
    }

    String id = MyBatisClassCreator.class.getName() + "." + clazz.getName();
    List<ResultMapping> resultMappings = new ArrayList<>();
    ResultMap resultMap = new ResultMap.Builder(null, id, clazz, resultMappings).build();

    return Collections.singletonList(resultMap);
  }
  
  /**
   * Creates a {@link SqlSource} backed by a {@link TemporarySqlSource} built
   * from the given arguments.
   *
   * @param configuration the MyBatis configuration passed to the SQL source
   * @param sql the SQL text
   * @param parameterMappings the parameter mappings for {@code sql}
   * @param parameterObject the parameter object for {@code sql}
   * @return the newly created SQL source
   */
  public SqlSource createSqlSource(Configuration configuration, 
                              String sql, 
                              List<ParameterMapping> parameterMappings, 
                              Object parameterObject) {
    return new TemporarySqlSource(configuration, sql, parameterMappings, parameterObject);
  }


}

/**
 * A {@link SqlSource} that wraps a single, pre-built {@link BoundSql}.
 *
 * <p>It is used to carry a SQL statement that has been modified, and is
 * consumed by the pagination interceptor.
 */
class TemporarySqlSource implements SqlSource {
  /** The bound SQL built once at construction and returned on every call. */
  private final BoundSql boundSql;

  /**
   * Builds the wrapped {@link BoundSql} from the given arguments.
   *
   * @param configuration the MyBatis configuration passed to {@link BoundSql}
   * @param sql the SQL text
   * @param parameterMappings the parameter mappings for {@code sql}
   * @param parameterObject the parameter object bound to the statement
   */
  public TemporarySqlSource(Configuration configuration, 
                            String sql, 
                            List<ParameterMapping> parameterMappings, 
                            Object parameterObject) {
    boundSql = new BoundSql(configuration, sql, parameterMappings, parameterObject);
  }

  /**
   * Returns the {@link BoundSql} created in the constructor.
   *
   * @param parameterObject ignored; the parameter object supplied at
   *        construction time is the one bound to the returned instance
   * @return the same {@link BoundSql} instance on every call
   */
  @Override
  public BoundSql getBoundSql(Object parameterObject) {
    return boundSql;
  }
}
